{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install \"lightly[timm]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"lightly[timm]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.models import utils\n",
    "from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer\n",
    "from lightly.transforms import AIMTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AIM(nn.Module):\n",
    "    def __init__(self, vit):\n",
    "        super().__init__()\n",
    "        utils.initialize_2d_sine_cosine_positional_embedding(\n",
    "            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token\n",
    "        )\n",
    "        self.patch_size = vit.patch_embed.patch_size[0]\n",
    "        self.num_patches = vit.patch_embed.num_patches\n",
    "\n",
    "        self.backbone = vit\n",
    "        self.projection_head = AIMPredictionHead(\n",
    "            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1\n",
    "        )\n",
    "\n",
    "    def forward(self, images):\n",
    "        batch_size = images.shape[0]\n",
    "\n",
    "        mask = utils.random_prefix_mask(\n",
    "            size=(batch_size, self.num_patches),\n",
    "            max_prefix_length=self.num_patches - 1,\n",
    "            device=images.device,\n",
    "        )\n",
    "        features = self.backbone.forward_features(images, mask=mask)\n",
    "        # Add positional embedding before head.\n",
    "        features = self.backbone._pos_embed(features)\n",
    "        predictions = self.projection_head(features)\n",
    "\n",
    "        # Convert images to patches and normalize them.\n",
    "        patches = utils.patchify(images, self.patch_size)\n",
    "        patches = utils.normalize_mean_var(patches, dim=-1)\n",
    "\n",
    "        return predictions, patches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "vit = MaskedCausalVisionTransformer(\n",
    "    img_size=224,\n",
    "    patch_size=32,\n",
    "    embed_dim=768,\n",
    "    depth=12,\n",
    "    num_heads=12,\n",
    "    qk_norm=False,\n",
    "    class_token=False,\n",
    "    no_embed_class=True,\n",
    ")\n",
    "model = AIM(vit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "transform = AIMTransform()\n",
    "# we ignore object detection annotations by setting target_transform to return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_transform(t):\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting Training\")\n",
    "for epoch in range(10):\n",
    "    total_loss = 0\n",
    "    for batch in dataloader:\n",
    "        views = batch[0]\n",
    "        images = views[0].to(device)  # views contains only a single view\n",
    "        predictions, targets = model(images)\n",
    "        loss = criterion(predictions, targets)\n",
    "        total_loss += loss.detach()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "    avg_loss = total_loss / len(dataloader)\n",
    "    print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
